import importlib
from datetime import datetime, timedelta, time

import backtrader as bt
from backtrader.feeds.binancedata import BinanceFinanceCSVData
from backtrader.plot.show import ShowPlot
from backtrader.sizers import AllInSizer
from backtrader.utils import logger, datehelper


class BackRun:
    """Drive one backtrader backtest over local Binance CSV data.

    ``run`` wires the broker, data feeds, strategy and analyzers into a
    ``bt.Cerebro`` instance, runs it, and then renders result plots through
    ``ShowPlot``.
    """

    def __init__(self):
        self.logger = logger.getLogger(__name__)

    def run(self, **kwargs):
        """Configure, execute and plot a single backtest.

        Required keyword arguments:
            init_cash: starting broker cash.
            comm: commission passed to ``broker.setcommission``; it is also
                handed to the strategy scaled by 10000 and rounded to an int.
            size: ``percents`` for ``AllInSizer``.
            start_date, ended_date: ``'%Y-%m-%d'`` strings bounding the feeds.
            tf: a timeframe string or a list of them.
            symbol: a symbol string (``"/"`` is replaced by ``"_"`` in file
                names) or a list of them.
            strat: importable module name containing a ``*Strategy`` class.
            params: dict of strategy parameters; it is copied, not modified.
            filter_exp: forwarded to ``optstrategy`` when ``opt`` is truthy.

        Optional keyword arguments:
            runonce (default True), leverage (default 1), opt,
            slippage (``{"type": "fixed" | "perc", "value": ...}``),
            type (asset type, default ``"spot"``), workspace (data root),
            resample (parsed, currently unused), plot (call ``cerebro.plot()``),
            benchmark (enables the benchmark plot; forwarded to ``ShowPlot``).

        Returns None. Logs an error and returns early when ``strat`` is empty.
        """
        for key, value in kwargs.items():
            self.logger.info(f"{key}={value}")

        cerebro = bt.Cerebro(optdatas=False, optreturn=False, runonce=kwargs.get("runonce", True))
        leverage = kwargs.get("leverage", 1)
        self._setup_broker(cerebro, kwargs, leverage)

        # Accept either a single string or a list for both tf and symbol;
        # iterating a bare symbol string would otherwise yield single characters.
        tfs = self._as_list(kwargs["tf"])
        symbols = self._as_list(kwargs["symbol"])
        self._add_datas(cerebro, kwargs, tfs, symbols)

        # Output directories and plots are keyed on the first symbol only.
        symbol_name = symbols[0].replace("/", "_")

        strat_module_name = kwargs.get("strat")
        if not strat_module_name:
            # Not inside an except block, so .error (not .exception) is appropriate.
            self.logger.error("parameter [strat] must be not none.")
            return
        # import_module returns the named submodule for dotted paths, whereas
        # __import__("a.b") would return the top-level package "a".
        clz = self._find_strategy_class(importlib.import_module(strat_module_name))

        # NOTE(review): the strategy receives the LAST configured timeframe
        # (previously via the leaked loop variable); confirm this is intended.
        params = self._build_strategy_params(kwargs["params"], kwargs, leverage, tfs[-1])
        if kwargs.get("opt"):
            cerebro.optstrategy(clz, filter_exp=kwargs["filter_exp"], **params)
        else:
            cerebro.addstrategy(clz, **params)

        out_style = f"{clz.alias[0]}/{symbol_name}/"
        for analyzer in (
            bt.analyzers.SimpleReport,
            bt.analyzers.Trades,
            bt.analyzers.Orders,
            bt.analyzers.NetAssetValue,
            bt.analyzers.Holdings,
            bt.analyzers.TimeDrawDownHistory,
        ):
            cerebro.addanalyzer(analyzer, csv=True, out=out_style)

        cerebro.run()
        if kwargs.get("plot"):
            cerebro.plot()

        # benchmark is optional: .get avoids the KeyError the unguarded
        # lookups used to raise when it was omitted.
        self._save_plots(cerebro, symbol_name, kwargs.get("benchmark"), [clz.alias[0]])

    def _setup_broker(self, cerebro, kwargs, leverage):
        """Apply cash, short-cash, slippage, commission and sizer settings."""
        cerebro.broker.set_cash(kwargs["init_cash"])
        cerebro.broker.set_shortcash(True)
        slippage = kwargs.get("slippage")
        if slippage:
            slip_opts = dict(slip_open=True, slip_limit=True, slip_match=True, slip_out=False)
            if slippage["type"] == "fixed":
                cerebro.broker.set_slippage_fixed(slippage["value"], **slip_opts)
            elif slippage["type"] == "perc":
                cerebro.broker.set_slippage_perc(slippage["value"], **slip_opts)
            else:
                self.logger.warning(f"Unknown slippage type {slippage['type']!r}; no slippage applied.")
        cerebro.broker.setcommission(commission=kwargs["comm"], stocklike=True, leverage=leverage)
        cerebro.addsizer(AllInSizer, percents=kwargs["size"])

    def _add_datas(self, cerebro, kwargs, tfs, symbols):
        """Add one ``BinanceFinanceCSVData`` feed per (timeframe, symbol) pair.

        Files are read from
        ``{workspace}/data/binance/market/{type}/{tf}/{symbol_with_underscores}.csv``.
        """
        fromdate = datetime.strptime(kwargs["start_date"], '%Y-%m-%d')
        todate = datetime.strptime(kwargs["ended_date"], '%Y-%m-%d')
        asset_type = kwargs.get("type", "spot")
        ws = kwargs.get("workspace", "/Users/wudi/Workspace/fund/workspace")

        for tf in tfs:
            # Timeframe parsing depends only on tf, so do it once per timeframe.
            timeframe, compression, _ = datehelper.parse_timeframe(tf)
            # Daily feeds get an explicit session window; others use the feed default.
            if tf == "1d":
                sessionstart, sessionend = time(7, 59, 59, 999990), time(8, 0, 0, 0)
            else:
                sessionstart, sessionend = None, None

            for symbol in symbols:
                file_symbol = symbol.replace("/", "_")
                data = BinanceFinanceCSVData(
                    dataname=f"{ws}/data/binance/market/{asset_type}/{tf}/{file_symbol}.csv",
                    name=file_symbol + "_" + tf,
                    fromdate=fromdate,
                    todate=todate,
                    sessionstart=sessionstart,
                    sessionend=sessionend,
                    timeframe=timeframe,
                    compression=compression,
                )
                cerebro.adddata(data)

        resample = kwargs.get("resample")
        if resample:
            # TODO: resampling is not wired up (cerebro.resampledata was disabled);
            # the value is still parsed as before, but the result is unused.
            datehelper.parse_timeframe(resample)

    def _save_plots(self, cerebro, symbol_name, benchmark, strats):
        """Render the result plots for every strategy instance cerebro ran."""
        fdir = "."
        for strategy in cerebro.runstrats:
            strat_params = bt.Analyzer.get_params1(strat=strategy[0])
            name = "_".join(map(str, strat_params.values()))
            # The benchmark plot is best-effort: failures are logged, not raised.
            try:
                if benchmark:
                    ShowPlot(fdir, symbol_name, benchmark, strats, show_fig=False).show(name=name, figsize=(30, 15))
            except Exception as e:
                self.logger.error(e)
                self.logger.warning("Plot benchmark error.")
            ShowPlot(fdir, symbol_name, benchmark, strats, show_fig=False).plot1(name=name, figsize=(30, 15))
            ShowPlot(fdir, symbol_name, benchmark, strats, show_fig=False).plot2(name=name, figsize=(30, 15))

    @staticmethod
    def _as_list(value):
        """Wrap a bare string in a list; return any other value unchanged."""
        return [value] if isinstance(value, str) else value

    @staticmethod
    def _to_basis_points(rate):
        """Return ``rate * 10000`` rounded to the nearest int.

        Plain ``int()`` truncates, so float error such as ``6.999999999999``
        would have been reported as 6 instead of 7.
        """
        return round(rate * 10000)

    @classmethod
    def _build_strategy_params(cls, params, kwargs, leverage, timeframe):
        """Return a copy of ``params`` extended with this run's identifying settings.

        The caller's dict is left untouched. Dates are turned into ``YYYYMMDD``
        ints and ``init_cash`` is truncated to an int.
        """
        merged = dict(params or {})
        merged["size"] = kwargs["size"]
        merged["comm"] = cls._to_basis_points(kwargs["comm"])
        merged["leverage"] = leverage
        merged["init_cash"] = int(kwargs["init_cash"])
        merged["start_date"] = int(kwargs["start_date"].replace("-", ""))
        merged["ended_date"] = int(kwargs["ended_date"].replace("-", ""))
        merged["timeframe"] = timeframe
        return merged

    @staticmethod
    def _find_strategy_class(module):
        """Return the strategy class exported by ``module``.

        Candidates are attributes whose name ends with ``Strategy``. Classes
        defined in ``module`` itself are preferred. This keeps an imported
        base class (e.g. ``Strategy`` brought in via ``from ... import``) from
        being chosen over the module's own strategy. Otherwise the first
        candidate in attribute insertion order is used, as before.

        Raises:
            Exception: if no attribute name ends with ``Strategy``.
        """
        namespace = vars(module)
        candidates = [name for name in namespace if name.endswith("Strategy")]
        if not candidates:
            # Message (Chinese): "strategy name must end with Strategy."
            raise Exception("策略名称必须以Strategy结尾.")
        own = [
            name
            for name in candidates
            if isinstance(namespace[name], type) and namespace[name].__module__ == module.__name__
        ]
        return namespace[(own or candidates)[0]]